{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tune `dolly-v2-7b` with Ray Train, PyTorch Lightning and FSDP\n",
    "\n",
    "<a id=\"try-anyscale-quickstart-dolly_lightning_fsdp_finetuning\" href=\"https://console.anyscale.com/register/ha?render_flow=ray&utm_source=ray_docs&utm_medium=docs&utm_campaign=dolly_lightning_fsdp_finetuning\">\n",
    "    <img src=\"../../../_static/img/run-on-anyscale.svg\" alt=\"try-anyscale-quickstart\">\n",
    "</a>\n",
    "<br></br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we demonstrate how to use Ray Train to fine-tune a [`dolly-v2-7b`](https://huggingface.co/databricks/dolly-v2-7b) model. `dolly-v2-7b` is a 7-billion-parameter causal language model created by Databricks, derived from EleutherAI’s [Pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b), and fine-tuned on a [~15K-record instruction corpus](https://github.com/databrickslabs/dolly/tree/master/data).\n",
    "\n",
    "We load the pre-trained model from the Hugging Face Model Hub into a LightningModule and launch an FSDP fine-tuning job across 16 T4 GPUs with {class}`Ray TorchTrainer <ray.train.torch.TorchTrainer>`. You can fine-tune other large language models of a similar architecture the same way.\n",
    "\n",
    "Before starting this example, we highly recommend reading [Ray Train Key Concepts](train-key-concepts) and [Ray Data Quickstart](data_quickstart)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up a Ray cluster\n",
    "In this example, we use a Ray cluster with a `g4dn.8xlarge` head node and 15 `g4dn.4xlarge` worker nodes, for 16 GPUs in total. Each instance has one Tesla T4 GPU with 16 GiB of memory.\n",
    "\n",
    "We define a `runtime_env` to install the necessary Python libraries on each node. You can skip this step if you have already installed all the required packages in your workers' base image. We tested this example with `lightning==2.0.2` and `transformers==4.29.2`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import ray\n",
    "\n",
    "ray.init(\n",
    "    runtime_env={\n",
    "        \"pip\": [\n",
    "            \"datasets\",\n",
    "            \"evaluate\",\n",
    "            \"transformers>=4.26.0\",\n",
    "            \"torch>=1.12.0\",\n",
    "            \"lightning>=2.0\",\n",
    "        ]\n",
    "    }\n",
    ")"
   ]
  },
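  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, confirm that the cluster exposes the 16 GPUs this example expects before you start training. The sketch below uses a hypothetical helper, `has_enough_gpus`, which is not part of Ray. It only assumes that `ray.cluster_resources()` returns a dictionary mapping resource names such as `GPU` to their total quantities.\n",
    "\n",
    "```python\n",
    "def has_enough_gpus(resources: dict, required: int = 16) -> bool:\n",
    "    # `resources` is expected to look like the output of ray.cluster_resources(),\n",
    "    # which reports GPU counts as floats under the 'GPU' key.\n",
    "    return resources.get('GPU', 0.0) >= required\n",
    "\n",
    "\n",
    "# On a running cluster, after ray.init():\n",
    "# import ray\n",
    "# assert has_enough_gpus(ray.cluster_resources()), 'expected 16 GPUs'\n",
    "```"
   ]
  },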
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "MODEL_NAME = \"databricks/dolly-v2-7b\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare your data\n",
    "We use the `tiny_shakespeare` dataset for fine-tuning. It contains 40,000 lines of Shakespeare from a variety of his plays and was featured in Andrej Karpathy's blog post ['The Unreasonable Effectiveness of Recurrent Neural Networks'](http://karpathy.github.io/2015/05/21/rnn-effectiveness/).\n",
    "\n",
    "Dataset samples:\n",
    "```\n",
    "BAPTISTA:\n",
    "I know him well: you are welcome for his sake.\n",
    "\n",
    "GREMIO:\n",
    "Saving your tale, Petruchio, I pray,\n",
    "Let us, that are poor petitioners, speak too:\n",
    "Baccare! you are marvellous forward.\n",
    "\n",
    "PETRUCHIO:\n",
    "O, pardon me, Signior Gremio; I would fain be doing.\n",
    "```\n",
    "\n",
    "The preprocessing logic here is adapted from another example: {doc}`GPT-J-6B Fine-Tuning with Ray Train and DeepSpeed <../deepspeed/gptj_deepspeed_fine_tuning>`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import ray\n",
    "import pandas as pd\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "def split_text(batch: pd.DataFrame) -> pd.DataFrame:\n",
    "    # Join the raw text, split it into lines, and drop empty lines as well as\n",
    "    # speaker names (lines that end with a colon, such as \"GREMIO:\").\n",
    "    text = list(batch[\"text\"])\n",
    "    flat_text = \"\".join(text)\n",
    "    lines = [\n",
    "        x.strip()\n",
    "        for x in flat_text.split(\"\\n\")\n",
    "        if x.strip() and not x.strip().endswith(\":\")\n",
    "    ]\n",
    "    return pd.DataFrame(lines, columns=[\"text\"])\n",
    "\n",
    "\n",
    "def tokenize(batch: pd.DataFrame) -> dict:\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side=\"left\")\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "    ret = tokenizer(\n",
    "        list(batch[\"text\"]),\n",
    "        truncation=True,\n",
    "        max_length=256,\n",
    "        padding=\"max_length\",\n",
    "        return_tensors=\"np\",\n",
    "    )\n",
    "    ret[\"labels\"] = ret[\"input_ids\"].copy()\n",
    "    return dict(ret)\n",
    "\n",
    "hf_dataset = load_dataset(\"tiny_shakespeare\", trust_remote_code=True)\n",
    "train_ds = ray.data.from_huggingface(hf_dataset[\"train\"])"
   ]
  },
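  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`split_text` keeps every non-empty line that does not end with a colon, which is how it removes speaker names such as `GREMIO:`. The plain-Python sketch below restates the same rule as a hypothetical `split_lines` helper, with no Ray or pandas needed, and applies it to the excerpt shown earlier. The heuristic also drops genuine dialogue lines that happen to end with a colon, such as `Let us, that are poor petitioners, speak too:`.\n",
    "\n",
    "```python\n",
    "SAMPLE = '''BAPTISTA:\n",
    "I know him well: you are welcome for his sake.\n",
    "\n",
    "GREMIO:\n",
    "Saving your tale, Petruchio, I pray,\n",
    "Let us, that are poor petitioners, speak too:\n",
    "Baccare! you are marvellous forward.\n",
    "\n",
    "PETRUCHIO:\n",
    "O, pardon me, Signior Gremio; I would fain be doing.\n",
    "'''\n",
    "\n",
    "\n",
    "def split_lines(text: str) -> list[str]:\n",
    "    # Same filter as `split_text`: strip whitespace, then drop empty lines and\n",
    "    # any line ending in ':' (speaker names, and occasionally dialogue).\n",
    "    return [\n",
    "        line.strip()\n",
    "        for line in text.splitlines()\n",
    "        if line.strip() and not line.strip().endswith(':')\n",
    "    ]\n",
    "\n",
    "\n",
    "for line in split_lines(SAMPLE):\n",
    "    print(line)\n",
    "```\n",
    "\n",
    "This prints four lines. The speaker names and GREMIO's colon-terminated line are removed."
   ]
  },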
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, split the raw text into individual lines of dialogue (dropping speaker names), then tokenize them. Here are some samples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-08-30 11:03:12,182\tINFO dataset.py:2380 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.\n",
      "2023-08-30 11:03:12,185\tINFO streaming_executor.py:93 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(split_text)] -> LimitOperator[limit=10]\n",
      "2023-08-30 11:03:12,186\tINFO streaming_executor.py:94 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
      "2023-08-30 11:03:12,187\tINFO streaming_executor.py:96 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e398e697cafb4b548fabc85020df5e87",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running 0:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[{'text': 'Before we proceed any further, hear me speak.'},\n",
       " {'text': 'Speak, speak.'},\n",
       " {'text': 'You are all resolved rather to die than to famish?'},\n",
       " {'text': 'Resolved. resolved.'},\n",
       " {'text': 'First, you know Caius Marcius is chief enemy to the people.'},\n",
       " {'text': \"We know't, we know't.\"},\n",
       " {'text': \"Let us kill him, and we'll have corn at our own price.\"},\n",
       " {'text': \"Is't a verdict?\"},\n",
       " {'text': \"No more talking on't; let it be done: away, away!\"},\n",
       " {'text': 'One word, good citizens.'}]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First, split the text into individual lines of dialogue.\n",
    "train_ds = train_ds.map_batches(split_text, batch_format=\"pandas\")\n",
    "train_ds.take(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Then tokenize the dataset.\n",
    "train_ds = train_ds.map_batches(tokenize, batch_format=\"pandas\")"
   ]
  },
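  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tokenize` sets `labels` to an unshifted copy of `input_ids`. This works because Hugging Face causal language models, including the GPT-NeoX architecture used here, shift the labels internally: the prediction at position `t` is scored against `labels[t + 1]`. The toy sketch below uses made-up token ids, with `0` standing in for the pad/EOS id, to show which pairs enter the loss. Padded positions are not set to `-100` (the loss's ignore index), so the left padding also contributes to the loss. The masked variant at the end is one optional way to exclude it and is not part of this example's pipeline.\n",
    "\n",
    "```python\n",
    "# Toy ids for one left-padded sequence; 0 stands in for the pad/EOS token.\n",
    "input_ids = [0, 0, 101, 102, 103]\n",
    "attention_mask = [0, 0, 1, 1, 1]\n",
    "labels = list(input_ids)  # what `tokenize` does: an unshifted copy\n",
    "\n",
    "# The model shifts internally, so input_ids[t] is paired with labels[t + 1].\n",
    "pairs = list(zip(input_ids[:-1], labels[1:]))\n",
    "print(pairs)  # [(0, 0), (0, 101), (101, 102), (102, 103)]\n",
    "\n",
    "# Optional variant, not used in this example: replace padded positions with\n",
    "# -100 so that the cross-entropy loss ignores them.\n",
    "IGNORE_INDEX = -100\n",
    "masked_labels = [\n",
    "    tok if keep else IGNORE_INDEX for tok, keep in zip(labels, attention_mask)\n",
    "]\n",
    "print(masked_labels)  # [-100, -100, 101, 102, 103]\n",
    "```"
   ]
  },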
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define your Lightning model\n",
    "\n",
    "In this example, we use the [dolly-v2-7b](https://huggingface.co/databricks/dolly-v2-7b) model for fine-tuning. It is an instruction-following large language model, trained on the Databricks machine learning platform and licensed for commercial use. We load the model weights from the Hugging Face Model Hub and wrap the model in a `pl.LightningModule`.\n",
    "\n",
    ":::{note}\n",
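    "Make sure you pass the FSDP-wrapped model parameters, `self.trainer.model.parameters()`, to the optimizer instead of `self.model.parameters()`. When FSDP wraps the model, it replaces the original parameters with flattened, sharded parameters owned by the FSDP wrappers, so the optimizer must be built from the wrapped module.\n",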
    ":::\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import lightning.pytorch as pl\n",
    "\n",
    "class DollyV2Model(pl.LightningModule):\n",
    "    def __init__(self, lr=2e-5, eps=1e-8):\n",
    "        super().__init__()\n",
    "        self.save_hyperparameters()\n",
    "        self.lr = lr\n",
    "        self.eps = eps\n",
    "        self.model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
    "\n",
    "    def forward(self, batch):\n",
    "        outputs = self.model(\n",
    "            batch[\"input_ids\"],\n",
    "            attention_mask=batch[\"attention_mask\"],\n",
    "            labels=batch[\"labels\"]\n",
    "        )\n",
    "        return outputs.loss\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        loss = self.forward(batch)\n",
    "        self.log(\"train_loss\", loss, prog_bar=True, on_step=True)\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        if self.global_rank == 0:\n",
    "            print(self.trainer.model)\n",
    "        return torch.optim.AdamW(self.trainer.model.parameters(), lr=self.lr, eps=self.eps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configure your FSDP strategy\n",
    "Because `dolly-v2-7b` is a relatively large model, it does not fit on a single commodity GPU such as a 16 GiB T4. In this example, we use the FSDP strategy to shard model parameters, gradients, and optimizer states across multiple workers. This lets us avoid GPU out-of-memory errors and support a larger global batch size.\n",
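    "\n",
    "The following back-of-the-envelope sketch illustrates the scale of the problem. It counts parameters from the layer shapes that `configure_optimizers` prints later in this example. It assumes roughly 16 bytes of model state per parameter (fp32 weights, fp32 gradients, and two fp32 AdamW moments) and ignores activations entirely, so treat the result as an approximation rather than a measurement.\n",
    "\n",
    "```python\n",
    "# Layer shapes of dolly-v2-7b (GPT-NeoX), as printed later in this example.\n",
    "HIDDEN, FFN, VOCAB, NUM_LAYERS = 4096, 16384, 50280, 32\n",
    "GIB = 1024 ** 3\n",
    "BYTES_PER_PARAM = 4 + 4 + 8  # assumption: fp32 weights + grads + AdamW moments\n",
    "\n",
    "\n",
    "def linear(n_in: int, n_out: int, bias: bool = True) -> int:\n",
    "    return n_in * n_out + (n_out if bias else 0)\n",
    "\n",
    "\n",
    "def layer_norm(dim: int) -> int:\n",
    "    return 2 * dim  # weight and bias\n",
    "\n",
    "\n",
    "per_layer = (\n",
    "    2 * layer_norm(HIDDEN)          # input and post-attention LayerNorms\n",
    "    + linear(HIDDEN, 3 * HIDDEN)    # fused query_key_value projection\n",
    "    + linear(HIDDEN, HIDDEN)        # attention output projection (dense)\n",
    "    + linear(HIDDEN, FFN)           # dense_h_to_4h\n",
    "    + linear(FFN, HIDDEN)           # dense_4h_to_h\n",
    ")\n",
    "num_params = (\n",
    "    NUM_LAYERS * per_layer\n",
    "    + VOCAB * HIDDEN                       # embed_in\n",
    "    + layer_norm(HIDDEN)                   # final_layer_norm\n",
    "    + linear(HIDDEN, VOCAB, bias=False)    # embed_out\n",
    ")\n",
    "\n",
    "total_gib = num_params * BYTES_PER_PARAM / GIB\n",
    "print(f'{num_params:,} parameters')  # 6,856,056,832 parameters\n",
    "print(f'~{total_gib:.0f} GiB of model state on one GPU')\n",
    "print(f'~{total_gib / 16:.1f} GiB per GPU when fully sharded across 16 GPUs')\n",
    "```\n",
    "\n",
    "Even before activations, the unsharded total is several times the 16 GiB of a single T4. The fully sharded share per GPU leaves room for activations.\n",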
    "\n",
    "![](https://user-images.githubusercontent.com/26745457/236892936-d4b91751-4689-421e-ac5f-edfd2eeeb635.png)\n",
    "Image source: [Fully Sharded Data Parallel: faster AI training with fewer GPUs](https://engineering.fb.com/2021/07/15/open-source/fsdp/)\n",
    "\n",
    ":::{note}\n",
    "FSDP is a type of data parallelism that shards model parameters, optimizer states, and gradients across DDP ranks. It was inspired by Xu et al. and by ZeRO Stage 3 from DeepSpeed. For more information, see:\n",
    "\n",
    "- [Fully Sharded Data Parallel: faster AI training with fewer GPUs](https://engineering.fb.com/2021/07/15/open-source/fsdp/)\n",
    "- [Getting Started with Fully Sharded Data Parallel(FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html#:~:text=FSDP%20is%20a%20type%20of,sizes%20for%20our%20training%20job.)\n",
    "- [PyTorch FSDP Tutorial](https://www.youtube.com/watch?v=8_k76AHu__s&list=PL_lsbAsL_o2BT6aerEKgIoufVD_fodnuT)\n",
    ":::\n",
    "\n",
    "To start training with Lightning's [FSDPStrategy](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.FSDPStrategy.html#lightning.pytorch.strategies.FSDPStrategy), you only need to create a {class}`~ray.train.lightning.RayFSDPStrategy` with the same initialization arguments.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import functools\n",
    "import lightning.pytorch as pl \n",
    "\n",
    "from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy\n",
    "from torch.distributed.fsdp import ShardingStrategy, BackwardPrefetch\n",
    "from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer\n",
    "\n",
    "from ray.train.lightning import RayFSDPStrategy\n",
    "\n",
    "\n",
    "# Define the model sharding policy:\n",
    "# Wrap every GPTNeoXLayer as its own FSDP instance\n",
    "auto_wrap_policy = functools.partial(\n",
    "    transformer_auto_wrap_policy,\n",
    "    transformer_layer_cls={GPTNeoXLayer},\n",
    ")\n",
    "\n",
    "fsdp_strategy = RayFSDPStrategy(\n",
    "    sharding_strategy=ShardingStrategy.FULL_SHARD,\n",
    "    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,\n",
    "    forward_prefetch=True,\n",
    "    auto_wrap_policy=auto_wrap_policy,\n",
    "    limit_all_gathers=True,\n",
    "    activation_checkpointing=[GPTNeoXLayer],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ":::{tip}\n",
    "\n",
    "Some tips for FSDP configuration:\n",
    "- `sharding_strategy`:\n",
    "    - `ShardingStrategy.NO_SHARD`: Parameters, gradients, and optimizer states are not sharded. Similar to DDP.\n",
    "    - `ShardingStrategy.SHARD_GRAD_OP`: Gradients and optimizer states are sharded during computation, and additionally, parameters are sharded outside computation. Similar to ZeRO stage-2.\n",
    "    - `ShardingStrategy.FULL_SHARD`: Parameters, gradients, and optimizer states are sharded. It uses the least GPU memory of the three options. Similar to ZeRO stage-3.\n",
    "- `auto_wrap_policy`:\n",
    "    - FSDP is usually applied in a nested fashion, with groups of layers wrapped in their own FSDP instances. During the forward and backward passes, each rank then only needs to gather the full parameters of one FSDP instance at a time, rather than those of the whole model.\n",
    "    - Use `transformer_auto_wrap_policy` to automatically wrap each Transformer block in its own FSDP instance.\n",
    "- `backward_prefetch` and `forward_prefetch`:\n",
    "    - Prefetch (all-gather) the parameters of the next FSDP instance while the current forward or backward computation runs. This can improve throughput but may slightly increase peak memory usage.\n",
    ":::"
   ]
  },
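  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the trade-off between sharding strategies more concrete, the sketch below estimates peak per-GPU memory for model states under each strategy. It is a simplified model loosely following the accounting in the ZeRO paper, not FSDP's exact behavior. It makes the following assumptions:\n",
    "\n",
    "- About 6.86 billion parameters, 16 bytes of model state per parameter, and 16 GPUs.\n",
    "- Activations and the temporary all-gather buffers are ignored.\n",
    "- `SHARD_GRAD_OP` holds the full weights throughout the forward and backward passes.\n",
    "\n",
    "`peak_gib` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "NUM_PARAMS = 6.86e9\n",
    "NUM_GPUS = 16\n",
    "GIB = 1024 ** 3\n",
    "\n",
    "\n",
    "def peak_gib(strategy: str, n: int = NUM_GPUS, p: float = NUM_PARAMS) -> float:\n",
    "    # Assumed bytes: fp32 weights (4), fp32 grads (4), AdamW moments (8).\n",
    "    weights, grads, optim = 4 * p, 4 * p, 8 * p\n",
    "    if strategy == 'NO_SHARD':  # everything replicated, like DDP\n",
    "        total = weights + grads + optim\n",
    "    elif strategy == 'SHARD_GRAD_OP':  # full weights during compute\n",
    "        total = weights + (grads + optim) / n\n",
    "    elif strategy == 'FULL_SHARD':  # everything sharded\n",
    "        total = (weights + grads + optim) / n\n",
    "    else:\n",
    "        raise ValueError(f'unknown strategy: {strategy}')\n",
    "    return total / GIB\n",
    "\n",
    "\n",
    "for name in ('NO_SHARD', 'SHARD_GRAD_OP', 'FULL_SHARD'):\n",
    "    print(f'{name:>13}: ~{peak_gib(name):5.1f} GiB per GPU')\n",
    "```\n",
    "\n",
    "Under these assumptions, only `FULL_SHARD` stays well below the 16 GiB of a T4. This is why the strategy above uses it."
   ]
  },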
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tune with Ray TorchTrainer\n",
    "\n",
    "Ray TorchTrainer allows you to scale your PyTorch Lightning training workload over multiple nodes. See {ref}`Configuring Scale and GPUs <train_scaling_config>` for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_workers = 16\n",
    "batch_size_per_worker = 10"
   ]
  },
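  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each worker reads its own shard of the dataset through `ray.train.get_dataset_shard` and draws `batch_size_per_worker` rows per step. One optimizer step therefore consumes `num_workers * batch_size_per_worker` rows across the cluster. The sketch below spells out that arithmetic. `steps_per_epoch` is a hypothetical helper that assumes the rows are split roughly evenly across workers.\n",
    "\n",
    "```python\n",
    "num_workers = 16\n",
    "batch_size_per_worker = 10\n",
    "\n",
    "# One optimizer step processes one batch on every worker.\n",
    "global_batch_size = num_workers * batch_size_per_worker\n",
    "\n",
    "\n",
    "def steps_per_epoch(num_rows: int) -> int:\n",
    "    # Floors the result, so a trailing partial batch is not counted.\n",
    "    return num_rows // global_batch_size\n",
    "\n",
    "\n",
    "print(global_batch_size)       # 160\n",
    "print(steps_per_epoch(1_600))  # 10\n",
    "```"
   ]
  },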
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "# To accelerate release tests\n",
    "train_ds = train_ds.limit(num_workers * batch_size_per_worker * 10)  # each worker has 10 batches"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Additionally, remember to define a Lightning callback that saves and reports checkpoints. Ray Train offers a simple implementation, {class}`~ray.train.lightning.RayTrainReportCallback`, which persists your checkpoint and metrics in remote storage at the end of each training epoch.\n",
    "\n",
    "Note that you can also implement your own report callback with custom logic, such as saving custom checkpoint files or reporting at a different frequency."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-08-30 11:03:16,651\tINFO streaming_executor.py:93 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(split_text)->MapBatches(tokenize)]\n",
      "2023-08-30 11:03:16,652\tINFO streaming_executor.py:94 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
      "2023-08-30 11:03:16,652\tINFO streaming_executor.py:96 -- Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0e66d43f7d44da0a99483390e78e113",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running 0:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading (…)okenizer_config.json: 100%|██████████| 450/450 [00:00<00:00, 66.9kB/s] \n",
      "Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s])\u001b[0m \n",
      "Downloading (…)/main/tokenizer.json: 100%|██████████| 2.11M/2.11M [00:00<00:00, 14.8MB/s]\n",
      "Downloading (…)cial_tokens_map.json: 100%|██████████| 228/228 [00:00<00:00, 143kB/s]m \n"
     ]
    }
   ],
   "source": [
    "from lightning.pytorch.callbacks import TQDMProgressBar\n",
    "\n",
    "# Custom progress bar: iterators over Ray Data shards don't report a length,\n",
    "# so set the number of iterations per epoch explicitly.\n",
    "class DollyV2ProgressBar(TQDMProgressBar):\n",
    "    def __init__(self, num_iters_per_epoch, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.num_iters_per_epoch = num_iters_per_epoch\n",
    "    \n",
    "    def on_train_epoch_start(self, trainer, *_):\n",
    "        super().on_train_epoch_start(trainer, *_)\n",
    "        self.train_progress_bar.reset(self.num_iters_per_epoch)\n",
    "\n",
    "num_iters_per_epoch = train_ds.count() // (num_workers * batch_size_per_worker)\n",
    "prog_bar = DollyV2ProgressBar(num_iters_per_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.train import Checkpoint\n",
    "from ray.train.lightning import RayLightningEnvironment, RayTrainReportCallback, prepare_trainer\n",
    "\n",
    "# Training function for each worker\n",
    "def train_func(config):\n",
    "    lr = config[\"lr\"]\n",
    "    eps = config[\"eps\"]\n",
    "    strategy = config[\"strategy\"]\n",
    "    batch_size_per_worker = config[\"batch_size_per_worker\"]\n",
    "\n",
    "    # Model\n",
    "    model = DollyV2Model(lr=lr, eps=eps)\n",
    "\n",
    "    # Ray Data Ingestion\n",
    "    train_ds = ray.train.get_dataset_shard(\"train\")\n",
    "    train_dataloader = train_ds.iter_torch_batches(batch_size=batch_size_per_worker)\n",
    "\n",
    "    # Lightning Trainer\n",
    "    trainer = pl.Trainer(\n",
    "        max_epochs=1, \n",
    "        devices=\"auto\",\n",
    "        accelerator=\"auto\", \n",
    "        precision=\"16-mixed\",\n",
    "        strategy=strategy,\n",
    "        plugins=[RayLightningEnvironment()],\n",
    "        callbacks=[RayTrainReportCallback()],\n",
    "        enable_checkpointing=False,\n",
    "    )\n",
    "\n",
    "    trainer = prepare_trainer(trainer)\n",
    "\n",
    "    trainer.fit(model, train_dataloaders=train_dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "Since this example runs with multiple nodes, we need to persist checkpoints\n",
    "and other outputs to some external storage for access after training has completed.\n",
    "**You should set up cloud storage or NFS, then replace `storage_path` with your own cloud bucket URI or NFS path.**\n",
    "\n",
    "See the [storage guide](tune-storage-options) for more details.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "storage_path=\"s3://your-bucket-here\"  # TODO: Set up cloud storage\n",
    "# storage_path=\"/mnt/path/to/nfs\"     # TODO: Alternatively, set up NFS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "storage_path = \"/mnt/cluster_storage\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"tuneStatus\">\n",
       "  <div style=\"display: flex;flex-direction: row\">\n",
       "    <div style=\"display: flex;flex-direction: column;\">\n",
       "      <h3>Tune Status</h3>\n",
       "      <table>\n",
       "<tbody>\n",
       "<tr><td>Current time:</td><td>2023-08-30 11:51:22</td></tr>\n",
       "<tr><td>Running for: </td><td>00:47:57.19        </td></tr>\n",
       "<tr><td>Memory:      </td><td>39.3/124.3 GiB     </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    </div>\n",
       "    <div class=\"vDivider\"></div>\n",
       "    <div class=\"systemInfo\">\n",
       "      <h3>System Info</h3>\n",
       "      Using FIFO scheduling algorithm.<br>Logical resource usage: 193.0/272 CPUs, 16.0/16 GPUs (0.0/16.0 accelerator_type:None)\n",
       "    </div>\n",
       "    \n",
       "  </div>\n",
       "  <div class=\"hDivider\"></div>\n",
       "  <div class=\"trialStatus\">\n",
       "    <h3>Trial Status</h3>\n",
       "    <table>\n",
       "<thead>\n",
       "<tr><th>Trial name              </th><th>status    </th><th>loc              </th><th style=\"text-align: right;\">  iter</th><th style=\"text-align: right;\">  total time (s)</th><th style=\"text-align: right;\">  train_loss</th><th style=\"text-align: right;\">  epoch</th><th style=\"text-align: right;\">  step</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>TorchTrainer_839b5_00000</td><td>TERMINATED</td><td>10.0.23.226:66074</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">         2868.15</td><td style=\"text-align: right;\">    0.176025</td><td style=\"text-align: right;\">      0</td><td style=\"text-align: right;\">   135</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "  </div>\n",
       "</div>\n",
       "<style>\n",
       ".tuneStatus {\n",
       "  color: var(--jp-ui-font-color1);\n",
       "}\n",
       ".tuneStatus .systemInfo {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus td {\n",
       "  white-space: nowrap;\n",
       "}\n",
       ".tuneStatus .trialStatus {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus h3 {\n",
       "  font-weight: bold;\n",
       "}\n",
       ".tuneStatus .hDivider {\n",
       "  border-bottom-width: var(--jp-border-width);\n",
       "  border-bottom-color: var(--jp-border-color0);\n",
       "  border-bottom-style: solid;\n",
       "}\n",
       ".tuneStatus .vDivider {\n",
       "  border-left-width: var(--jp-border-width);\n",
       "  border-left-color: var(--jp-border-color0);\n",
       "  border-left-style: solid;\n",
       "  margin: 0.5em 1em 0.5em 1em;\n",
       "}\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(TrainTrainable pid=66074)\u001b[0m StorageContext on SESSION (rank=None):\n",
      "\u001b[2m\u001b[36m(TrainTrainable pid=66074)\u001b[0m StorageContext<\n",
      "\u001b[2m\u001b[36m(TrainTrainable pid=66074)\u001b[0m   storage_path=/mnt/cluster_storage\n",
      "\u001b[2m\u001b[36m(TrainTrainable pid=66074)\u001b[0m   storage_local_path=/home/ray/ray_results\n",
      "\u001b[2m\u001b[36m(TrainTrainable pid=66074)\u001b[0m   storage_filesystem=<pyarrow._fs.LocalFileSystem object at 0x7f8b0a2cd7f0>\n",
      "\u001b[2m\u001b[36m(TrainTrainable pid=66074)\u001b[0m   storage_fs_path=/mnt/cluster_storage\n",
      "\u001b[2m\u001b[36m(TrainTrainable pid=66074)\u001b[0m   experiment_dir_name=finetune_dolly-v2-7b\n",
      "\u001b[2m\u001b[36m(TrainTrainable pid=66074)\u001b[0m   trial_dir_name=TorchTrainer_839b5_00000_0_2023-08-30_11-03-25\n",
      "\u001b[2m\u001b[36m(TrainTrainable pid=66074)\u001b[0m   current_checkpoint_index=0\n",
      "\u001b[2m\u001b[36m(TrainTrainable pid=66074)\u001b[0m >\n",
      "\u001b[2m\u001b[36m(TorchTrainer pid=66074)\u001b[0m Starting distributed worker processes: ['66181 (10.0.23.226)', '14250 (10.0.40.16)', '13932 (10.0.2.17)', '13832 (10.0.41.56)', '14288 (10.0.53.250)', '13909 (10.0.41.152)', '13803 (10.0.14.94)', '47214 (10.0.44.99)', '13836 (10.0.58.27)', '13838 (10.0.58.206)', '13755 (10.0.62.244)', '13828 (10.0.9.99)', '13771 (10.0.43.35)', '13726 (10.0.59.245)', '13829 (10.0.58.178)', '13861 (10.0.46.116)']\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m Setting up process group for: env:// [rank=0, world_size=16]\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m StorageContext on SESSION (rank=0):\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=14250, ip=10.0.40.16)\u001b[0m StorageContext<\u001b[32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=14250, ip=10.0.40.16)\u001b[0m   storage_path=/mnt/cluster_storage\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=14250, ip=10.0.40.16)\u001b[0m   storage_local_path=/home/ray/ray_results\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=14250, ip=10.0.40.16)\u001b[0m   storage_filesystem=<pyarrow._fs.LocalFileSystem object at 0x7ff145c40b30>\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=14250, ip=10.0.40.16)\u001b[0m   storage_fs_path=/mnt/cluster_storage\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=14250, ip=10.0.40.16)\u001b[0m   current_checkpoint_index=0\u001b[32m [repeated 6x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=14250, ip=10.0.40.16)\u001b[0m >\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(SplitCoordinator pid=66292)\u001b[0m Auto configuring locality_with_output=['5ed99d043a52f67deb150f34202c09b77bd37409502ebf6e581b0544', '8efe8198d7c04d45714ae757f298c316117405f3a8b25b87a71e0d9e', 'e3754d1e1017e68dd919b35d35ea62ed7b005ad96452f371721fc9fa', '8bd0f431ab3733c4b423c1d50db06460e3c210de47355b3b4d215c31', '73a8b9377fe9531a84eaa7b30c966fbb11bc36aff070d55c8f7acd1a', 'ef922c93f3b2fc93ebe5a521426d24fb8aae7e13c65f9fbd106aea2a', '5249cff3eab41121f840c17a79e6a3cd0af0f059def707a39e055fcf', '042b668e5553a589a4f6693c45deee0abe57a1d754812172af425acb', '9ed138bfe1f9c7dca484ee08d8311806389adb3af7a76566a6f4dfaa', '7e2fcb5dfe4ab1b572d87257f9e13bbc22b33ba968b1e67a79505589', '39b1ef4da8493a22e321a1ea9dd13387f50d9a6e2d2fbad58ad5fe9c', '9484193409a5346c0838a4a19a0a08eec122477682ea1cb0ad3e305a', '0158084645ec305bdd2ab11a6f35c44ad206405ca810e65f24b09398', 'fe5b11633900d1c437b2e3ee4ea44c18cf68f3dece546537d2090c63', '573645f42162f531a66d20776a95ba05102fae8e4b8090d48b94b233', '47e317ad5d0eb94cabb78871541160763283629d0d3f3b77b69521ae']\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m Using 16bit Automatic Mixed Precision (AMP)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m GPU available: True (cuda), used: True\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m TPU available: False, using: 0 TPU cores\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m IPU available: False, using: 0 IPUs\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m HPU available: False, using: 0 HPUs\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13726, ip=10.0.59.245)\u001b[0m StorageContext on SESSION (rank=13):\u001b[32m [repeated 15x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13726, ip=10.0.59.245)\u001b[0m StorageContext<\u001b[32m [repeated 14x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13726, ip=10.0.59.245)\u001b[0m   storage_path=/mnt/cluster_storage\u001b[32m [repeated 14x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13726, ip=10.0.59.245)\u001b[0m   storage_local_path=/home/ray/ray_results\u001b[32m [repeated 14x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13726, ip=10.0.59.245)\u001b[0m   storage_filesystem=<pyarrow._fs.LocalFileSystem object at 0x7fc9841d7a70>\u001b[32m [repeated 14x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13726, ip=10.0.59.245)\u001b[0m   storage_fs_path=/mnt/cluster_storage\u001b[32m [repeated 14x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13726, ip=10.0.59.245)\u001b[0m   current_checkpoint_index=0\u001b[32m [repeated 42x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13726, ip=10.0.59.245)\u001b[0m >\u001b[32m [repeated 14x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m Missing logger folder: /home/ray/ray_results/finetune_dolly-v2-7b/TorchTrainer_839b5_00000_0_2023-08-30_11-03-25/lightning_logs\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13832, ip=10.0.41.56)\u001b[0m Using 16bit Automatic Mixed Precision (AMP)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13836, ip=10.0.58.27)\u001b[0m Using 16bit Automatic Mixed Precision (AMP)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13832, ip=10.0.41.56)\u001b[0m Missing logger folder: /home/ray/ray_results/finetune_dolly-v2-7b/TorchTrainer_839b5_00000_0_2023-08-30_11-03-25/lightning_logs\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13726, ip=10.0.59.245)\u001b[0m Using 16bit Automatic Mixed Precision (AMP)\u001b[32m [repeated 13x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=47214, ip=10.0.44.99)\u001b[0m Missing logger folder: /home/ray/ray_results/finetune_dolly-v2-7b/TorchTrainer_839b5_00000_0_2023-08-30_11-03-25/lightning_logs\u001b[32m [repeated 13x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13909, ip=10.0.41.152)\u001b[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m FullyShardedDataParallel(\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m   (_fsdp_wrapped_module): _LightningModuleWrapperBase(\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m     (_forward_module): DollyV2Model(\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m       (model): GPTNeoXForCausalLM(\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m         (gpt_neox): GPTNeoXModel(\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m           (embed_in): Embedding(50280, 4096)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m           (layers): ModuleList(\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m             (0-31): 32 x FullyShardedDataParallel(\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m               (_fsdp_wrapped_module): CheckpointWrapper(\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                 (_checkpoint_wrapped_module): GPTNeoXLayer(\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                   (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                   (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                   (attention): GPTNeoXAttention(\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                     (rotary_emb): RotaryEmbedding()\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                     (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                     (dense): Linear(in_features=4096, out_features=4096, bias=True)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                   )\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                   (mlp): GPTNeoXMLP(\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                     (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                     (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                     (act): GELUActivation()\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                   )\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m                 )\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m               )\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m             )\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m           )\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m           (final_layer_norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m         )\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m         (embed_out): Linear(in_features=4096, out_features=50280, bias=False)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m       )\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m     )\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m   )\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m )\n",
      "Epoch 0:   0%|          | 0/134 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m \n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m   | Name  | Type               | Params\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m ---------------------------------------------\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m 0 | model | GPTNeoXForCausalLM | 402 M \n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m ---------------------------------------------\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m 402 M     Trainable params\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m 0         Non-trainable params\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m 402 M     Total params\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m 1,611.039 Total estimated model params size (MB)\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13726, ip=10.0.59.245)\u001b[0m Missing logger folder: /home/ray/ray_results/finetune_dolly-v2-7b/TorchTrainer_839b5_00000_0_2023-08-30_11-03-25/lightning_logs\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13803, ip=10.0.14.94)\u001b[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\u001b[32m [repeated 15x across cluster]\u001b[0m\n",
      "\u001b[2m\u001b[36m(SplitCoordinator pid=66292)\u001b[0m Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(split_text)->MapBatches(tokenize)] -> OutputSplitter[split(16, equal=True)]\n",
      "\u001b[2m\u001b[36m(SplitCoordinator pid=66292)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=2000000000.0), locality_with_output=['5ed99d043a52f67deb150f34202c09b77bd37409502ebf6e581b0544', '8efe8198d7c04d45714ae757f298c316117405f3a8b25b87a71e0d9e', 'e3754d1e1017e68dd919b35d35ea62ed7b005ad96452f371721fc9fa', '8bd0f431ab3733c4b423c1d50db06460e3c210de47355b3b4d215c31', '73a8b9377fe9531a84eaa7b30c966fbb11bc36aff070d55c8f7acd1a', 'ef922c93f3b2fc93ebe5a521426d24fb8aae7e13c65f9fbd106aea2a', '5249cff3eab41121f840c17a79e6a3cd0af0f059def707a39e055fcf', '042b668e5553a589a4f6693c45deee0abe57a1d754812172af425acb', '9ed138bfe1f9c7dca484ee08d8311806389adb3af7a76566a6f4dfaa', '7e2fcb5dfe4ab1b572d87257f9e13bbc22b33ba968b1e67a79505589', '39b1ef4da8493a22e321a1ea9dd13387f50d9a6e2d2fbad58ad5fe9c', '9484193409a5346c0838a4a19a0a08eec122477682ea1cb0ad3e305a', '0158084645ec305bdd2ab11a6f35c44ad206405ca810e65f24b09398', 'fe5b11633900d1c437b2e3ee4ea44c18cf68f3dece546537d2090c63', '573645f42162f531a66d20776a95ba05102fae8e4b8090d48b94b233', '47e317ad5d0eb94cabb78871541160763283629d0d3f3b77b69521ae'], preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
      "\u001b[2m\u001b[36m(SplitCoordinator pid=66292)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0:   1%|          | 1/134 [00:25<57:20, 25.87s/it, v_num=0, train_loss=12.90]\n",
      "Epoch 0:   1%|▏         | 2/134 [00:43<47:24, 21.55s/it, v_num=0, train_loss=12.50]\n",
      "Epoch 0:   2%|▏         | 3/134 [01:00<43:55, 20.12s/it, v_num=0, train_loss=12.50]\n",
      "Epoch 0:   3%|▎         | 4/134 [01:21<44:23, 20.49s/it, v_num=0, train_loss=12.50]\n",
      "Epoch 0:   4%|▎         | 5/134 [01:39<42:42, 19.87s/it, v_num=0, train_loss=12.50]\n",
      "Epoch 0:   4%|▍         | 6/134 [01:56<41:29, 19.45s/it, v_num=0, train_loss=12.50]\n",
      "\u001b[2m\u001b[1m\u001b[36m(autoscaler +4m7s)\u001b[0m Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.\n",
      "Epoch 0:   5%|▌         | 7/134 [02:13<40:30, 19.14s/it, v_num=0, train_loss=12.50]\n",
      "Epoch 0:   6%|▌         | 8/134 [02:31<39:43, 18.92s/it, v_num=0, train_loss=12.50]\n",
      "Epoch 0:   7%|▋         | 9/134 [02:48<39:04, 18.76s/it, v_num=0, train_loss=12.50]\n",
      "Epoch 0:   7%|▋         | 9/134 [02:48<39:06, 18.77s/it, v_num=0, train_loss=12.50]\n",
      "Epoch 0:   7%|▋         | 10/134 [03:06<38:33, 18.66s/it, v_num=0, train_loss=12.50]\n",
      "Epoch 0:   7%|▋         | 10/134 [03:06<38:35, 18.67s/it, v_num=0, train_loss=0.587]\n",
      "Epoch 0:   8%|▊         | 11/134 [03:24<38:02, 18.56s/it, v_num=0, train_loss=0.587]\n",
      "Epoch 0:   8%|▊         | 11/134 [03:24<38:04, 18.57s/it, v_num=0, train_loss=0.600]\n",
      "Epoch 0:   9%|▉         | 12/134 [03:41<37:34, 18.48s/it, v_num=0, train_loss=0.600]\n",
      "Epoch 0:   9%|▉         | 12/134 [03:41<37:36, 18.49s/it, v_num=0, train_loss=0.590]\n",
      "Epoch 0:  10%|▉         | 13/134 [03:59<37:08, 18.41s/it, v_num=0, train_loss=0.590]\n",
      "Epoch 0:  10%|▉         | 13/134 [03:59<37:09, 18.43s/it, v_num=0, train_loss=0.591]\n",
      "Epoch 0:  10%|█         | 14/134 [04:17<36:44, 18.37s/it, v_num=0, train_loss=0.591]\n",
      "Epoch 0:  10%|█         | 14/134 [04:17<36:45, 18.38s/it, v_num=0, train_loss=0.590]\n",
      "Epoch 0:  11%|█         | 15/134 [04:34<36:19, 18.32s/it, v_num=0, train_loss=0.590]\n",
      "Epoch 0:  11%|█         | 15/134 [04:34<36:21, 18.33s/it, v_num=0, train_loss=0.551]\n",
      "Epoch 0:  12%|█▏        | 16/134 [04:52<35:57, 18.29s/it, v_num=0, train_loss=0.551]\n",
      "Epoch 0:  12%|█▏        | 16/134 [04:52<35:58, 18.29s/it, v_num=0, train_loss=0.521]\n",
      "Epoch 0:  13%|█▎        | 17/134 [05:10<35:35, 18.25s/it, v_num=0, train_loss=0.521]\n",
      "Epoch 0:  13%|█▎        | 17/134 [05:10<35:36, 18.26s/it, v_num=0, train_loss=0.522]\n",
      "Epoch 0:  13%|█▎        | 18/134 [05:28<35:14, 18.23s/it, v_num=0, train_loss=0.522]\n",
      "Epoch 0:  13%|█▎        | 18/134 [05:28<35:15, 18.23s/it, v_num=0, train_loss=0.518]\n",
      "Epoch 0:  14%|█▍        | 19/134 [05:45<34:52, 18.19s/it, v_num=0, train_loss=0.518]\n",
      "Epoch 0:  14%|█▍        | 19/134 [05:45<34:52, 18.20s/it, v_num=0, train_loss=0.476]\n",
      "Epoch 0:  15%|█▍        | 20/134 [06:03<34:30, 18.16s/it, v_num=0, train_loss=0.476]\n",
      "Epoch 0:  15%|█▍        | 20/134 [06:03<34:31, 18.17s/it, v_num=0, train_loss=0.457]\n",
      "Epoch 0:  16%|█▌        | 21/134 [06:23<34:23, 18.26s/it, v_num=0, train_loss=0.457]\n",
      "Epoch 0:  16%|█▌        | 21/134 [06:23<34:23, 18.27s/it, v_num=0, train_loss=0.476]\n",
      "Epoch 0:  16%|█▋        | 22/134 [06:42<34:09, 18.30s/it, v_num=0, train_loss=0.476]\n",
      "Epoch 0:  16%|█▋        | 22/134 [06:42<34:10, 18.31s/it, v_num=0, train_loss=0.447]\n",
      "Epoch 0:  17%|█▋        | 23/134 [07:00<33:49, 18.28s/it, v_num=0, train_loss=0.447]\n",
      "Epoch 0:  17%|█▋        | 23/134 [07:00<33:49, 18.29s/it, v_num=0, train_loss=0.412]\n",
      "Epoch 0:  18%|█▊        | 24/134 [07:19<33:33, 18.31s/it, v_num=0, train_loss=0.412]\n",
      "Epoch 0:  18%|█▊        | 24/134 [07:19<33:34, 18.31s/it, v_num=0, train_loss=0.385]\n",
      "Epoch 0:  19%|█▊        | 25/134 [07:37<33:13, 18.29s/it, v_num=0, train_loss=0.385]\n",
      "Epoch 0:  19%|█▊        | 25/134 [07:37<33:13, 18.29s/it, v_num=0, train_loss=0.384]\n",
      "Epoch 0:  19%|█▉        | 26/134 [07:54<32:52, 18.27s/it, v_num=0, train_loss=0.384]\n",
      "Epoch 0:  19%|█▉        | 26/134 [07:55<32:53, 18.27s/it, v_num=0, train_loss=0.406]\n",
      "Epoch 0:  20%|██        | 27/134 [08:12<32:32, 18.25s/it, v_num=0, train_loss=0.406]\n",
      "Epoch 0:  20%|██        | 27/134 [08:12<32:32, 18.25s/it, v_num=0, train_loss=0.380]\n",
      "Epoch 0:  21%|██        | 28/134 [08:30<32:12, 18.23s/it, v_num=0, train_loss=0.380]\n",
      "Epoch 0:  21%|██        | 28/134 [08:30<32:12, 18.23s/it, v_num=0, train_loss=0.405]\n",
      "Epoch 0:  22%|██▏       | 29/134 [08:48<31:52, 18.21s/it, v_num=0, train_loss=0.405]\n",
      "Epoch 0:  22%|██▏       | 29/134 [08:48<31:52, 18.22s/it, v_num=0, train_loss=0.355]\n",
      "Epoch 0:  22%|██▏       | 30/134 [09:06<31:32, 18.20s/it, v_num=0, train_loss=0.355]\n",
      "Epoch 0:  22%|██▏       | 30/134 [09:06<31:33, 18.21s/it, v_num=0, train_loss=0.376]\n",
      "Epoch 0:  23%|██▎       | 31/134 [09:23<31:12, 18.18s/it, v_num=0, train_loss=0.376]\n",
      "Epoch 0:  23%|██▎       | 31/134 [09:23<31:13, 18.19s/it, v_num=0, train_loss=0.330]\n",
      "Epoch 0:  24%|██▍       | 32/134 [09:41<30:53, 18.17s/it, v_num=0, train_loss=0.330]\n",
      "Epoch 0:  24%|██▍       | 32/134 [09:41<30:53, 18.17s/it, v_num=0, train_loss=0.359]\n",
      "Epoch 0:  25%|██▍       | 33/134 [09:59<30:33, 18.15s/it, v_num=0, train_loss=0.359]\n",
      "Epoch 0:  25%|██▍       | 33/134 [09:59<30:33, 18.16s/it, v_num=0, train_loss=0.319]\n",
      "Epoch 0:  25%|██▌       | 34/134 [10:16<30:14, 18.14s/it, v_num=0, train_loss=0.319]\n",
      "Epoch 0:  25%|██▌       | 34/134 [10:16<30:14, 18.15s/it, v_num=0, train_loss=0.359]\n",
      "Epoch 0:  26%|██▌       | 35/134 [10:34<29:54, 18.13s/it, v_num=0, train_loss=0.359]\n",
      "Epoch 0:  26%|██▌       | 35/134 [10:34<29:54, 18.13s/it, v_num=0, train_loss=0.405]\n",
      "Epoch 0:  27%|██▋       | 36/134 [10:52<29:35, 18.12s/it, v_num=0, train_loss=0.405]\n",
      "Epoch 0:  27%|██▋       | 36/134 [10:52<29:35, 18.12s/it, v_num=0, train_loss=0.362]\n",
      "Epoch 0:  28%|██▊       | 37/134 [11:09<29:16, 18.10s/it, v_num=0, train_loss=0.362]\n",
      "Epoch 0:  28%|██▊       | 37/134 [11:10<29:16, 18.11s/it, v_num=0, train_loss=0.343]\n",
      "Epoch 0:  28%|██▊       | 38/134 [11:27<28:56, 18.09s/it, v_num=0, train_loss=0.343]\n",
      "Epoch 0:  28%|██▊       | 38/134 [11:27<28:57, 18.10s/it, v_num=0, train_loss=0.335]\n",
      "Epoch 0:  29%|██▉       | 39/134 [11:45<28:38, 18.08s/it, v_num=0, train_loss=0.335]\n",
      "Epoch 0:  29%|██▉       | 39/134 [11:45<28:38, 18.09s/it, v_num=0, train_loss=0.325]\n",
      "Epoch 0:  30%|██▉       | 40/134 [12:03<28:19, 18.08s/it, v_num=0, train_loss=0.325]\n",
      "Epoch 0:  30%|██▉       | 40/134 [12:03<28:19, 18.08s/it, v_num=0, train_loss=0.344]\n",
      "Epoch 0:  31%|███       | 41/134 [12:20<27:59, 18.06s/it, v_num=0, train_loss=0.344]\n",
      "Epoch 0:  31%|███       | 41/134 [12:20<28:00, 18.06s/it, v_num=0, train_loss=0.312]\n",
      "Epoch 0:  31%|███▏      | 42/134 [12:38<27:40, 18.05s/it, v_num=0, train_loss=0.312]\n",
      "Epoch 0:  31%|███▏      | 42/134 [12:38<27:40, 18.05s/it, v_num=0, train_loss=0.338]\n",
      "Epoch 0:  32%|███▏      | 43/134 [12:55<27:21, 18.04s/it, v_num=0, train_loss=0.338]\n",
      "Epoch 0:  32%|███▏      | 43/134 [12:55<27:21, 18.04s/it, v_num=0, train_loss=0.315]\n",
      "Epoch 0:  33%|███▎      | 44/134 [13:13<27:02, 18.03s/it, v_num=0, train_loss=0.315]\n",
      "Epoch 0:  33%|███▎      | 44/134 [13:13<27:02, 18.03s/it, v_num=0, train_loss=0.330]\n",
      "Epoch 0:  34%|███▎      | 45/134 [13:31<26:44, 18.02s/it, v_num=0, train_loss=0.330]\n",
      "Epoch 0:  34%|███▎      | 45/134 [13:31<26:44, 18.03s/it, v_num=0, train_loss=0.253]\n",
      "Epoch 0:  34%|███▍      | 46/134 [13:48<26:25, 18.02s/it, v_num=0, train_loss=0.253]\n",
      "Epoch 0:  34%|███▍      | 46/134 [13:49<26:25, 18.02s/it, v_num=0, train_loss=0.310]\n",
      "Epoch 0:  35%|███▌      | 47/134 [14:06<26:07, 18.01s/it, v_num=0, train_loss=0.310]\n",
      "Epoch 0:  35%|███▌      | 47/134 [14:06<26:07, 18.01s/it, v_num=0, train_loss=0.294]\n",
      "Epoch 0:  36%|███▌      | 48/134 [14:24<25:48, 18.01s/it, v_num=0, train_loss=0.294]\n",
      "Epoch 0:  36%|███▌      | 48/134 [14:24<25:48, 18.01s/it, v_num=0, train_loss=0.302]\n",
      "Epoch 0:  37%|███▋      | 49/134 [14:41<25:29, 18.00s/it, v_num=0, train_loss=0.302]\n",
      "Epoch 0:  37%|███▋      | 49/134 [14:42<25:30, 18.00s/it, v_num=0, train_loss=0.325]\n",
      "Epoch 0:  37%|███▋      | 50/134 [14:59<25:11, 17.99s/it, v_num=0, train_loss=0.325]\n",
      "Epoch 0:  37%|███▋      | 50/134 [14:59<25:11, 18.00s/it, v_num=0, train_loss=0.250]\n",
      "Epoch 0:  38%|███▊      | 51/134 [15:17<24:53, 17.99s/it, v_num=0, train_loss=0.250]\n",
      "Epoch 0:  38%|███▊      | 51/134 [15:17<24:53, 17.99s/it, v_num=0, train_loss=0.291]\n",
      "Epoch 0:  39%|███▉      | 52/134 [15:34<24:34, 17.98s/it, v_num=0, train_loss=0.291]\n",
      "Epoch 0:  39%|███▉      | 52/134 [15:35<24:34, 17.98s/it, v_num=0, train_loss=0.261]\n",
      "Epoch 0:  40%|███▉      | 53/134 [15:52<24:15, 17.98s/it, v_num=0, train_loss=0.261]\n",
      "Epoch 0:  40%|███▉      | 53/134 [15:52<24:16, 17.98s/it, v_num=0, train_loss=0.292]\n",
      "Epoch 0:  40%|████      | 54/134 [16:10<23:57, 17.97s/it, v_num=0, train_loss=0.292]\n",
      "Epoch 0:  40%|████      | 54/134 [16:10<23:58, 17.98s/it, v_num=0, train_loss=0.245]\n",
      "Epoch 0:  41%|████      | 55/134 [16:28<23:39, 17.97s/it, v_num=0, train_loss=0.245]\n",
      "Epoch 0:  41%|████      | 55/134 [16:28<23:39, 17.97s/it, v_num=0, train_loss=0.265]\n",
      "Epoch 0:  42%|████▏     | 56/134 [16:45<23:20, 17.96s/it, v_num=0, train_loss=0.265]\n",
      "Epoch 0:  42%|████▏     | 56/134 [16:45<23:21, 17.96s/it, v_num=0, train_loss=0.233]\n",
      "Epoch 0:  43%|████▎     | 57/134 [17:03<23:02, 17.96s/it, v_num=0, train_loss=0.233]\n",
      "Epoch 0:  43%|████▎     | 57/134 [17:03<23:02, 17.96s/it, v_num=0, train_loss=0.228]\n",
      "Epoch 0:  43%|████▎     | 58/134 [17:21<22:44, 17.96s/it, v_num=0, train_loss=0.228]\n",
      "Epoch 0:  43%|████▎     | 58/134 [17:21<22:45, 17.96s/it, v_num=0, train_loss=0.222]\n",
      "Epoch 0:  44%|████▍     | 59/134 [17:39<22:26, 17.96s/it, v_num=0, train_loss=0.222]\n",
      "Epoch 0:  44%|████▍     | 59/134 [17:39<22:27, 17.96s/it, v_num=0, train_loss=0.240]\n",
      "Epoch 0:  45%|████▍     | 60/134 [17:57<22:08, 17.96s/it, v_num=0, train_loss=0.240]\n",
      "Epoch 0:  45%|████▍     | 60/134 [17:57<22:08, 17.96s/it, v_num=0, train_loss=0.220]\n",
      "Epoch 0:  46%|████▌     | 61/134 [18:15<21:50, 17.95s/it, v_num=0, train_loss=0.220]\n",
      "Epoch 0:  46%|████▌     | 61/134 [18:15<21:50, 17.95s/it, v_num=0, train_loss=0.235]\n",
      "Epoch 0:  46%|████▋     | 62/134 [18:32<21:32, 17.95s/it, v_num=0, train_loss=0.235]\n",
      "Epoch 0:  46%|████▋     | 62/134 [18:32<21:32, 17.95s/it, v_num=0, train_loss=0.230]\n",
      "Epoch 0:  47%|████▋     | 63/134 [18:50<21:14, 17.95s/it, v_num=0, train_loss=0.230]\n",
      "Epoch 0:  47%|████▋     | 63/134 [18:50<21:14, 17.95s/it, v_num=0, train_loss=0.247]\n",
      "Epoch 0:  48%|████▊     | 64/134 [19:08<20:55, 17.94s/it, v_num=0, train_loss=0.247]\n",
      "Epoch 0:  48%|████▊     | 64/134 [19:08<20:56, 17.94s/it, v_num=0, train_loss=0.243]\n",
      "Epoch 0:  49%|████▊     | 65/134 [19:25<20:37, 17.94s/it, v_num=0, train_loss=0.243]\n",
      "Epoch 0:  49%|████▊     | 65/134 [19:26<20:37, 17.94s/it, v_num=0, train_loss=0.233]\n",
      "Epoch 0:  49%|████▉     | 66/134 [19:43<20:19, 17.93s/it, v_num=0, train_loss=0.233]\n",
      "Epoch 0:  49%|████▉     | 66/134 [19:43<20:19, 17.94s/it, v_num=0, train_loss=0.253]\n",
      "Epoch 0:  50%|█████     | 67/134 [20:01<20:01, 17.93s/it, v_num=0, train_loss=0.253]\n",
      "Epoch 0:  50%|█████     | 67/134 [20:01<20:01, 17.93s/it, v_num=0, train_loss=0.235]\n",
      "Epoch 0:  51%|█████     | 68/134 [20:19<19:43, 17.93s/it, v_num=0, train_loss=0.235]\n",
      "Epoch 0:  51%|█████     | 68/134 [20:19<19:43, 17.93s/it, v_num=0, train_loss=0.270]\n",
      "Epoch 0:  51%|█████▏    | 69/134 [20:37<19:25, 17.93s/it, v_num=0, train_loss=0.270]\n",
      "Epoch 0:  51%|█████▏    | 69/134 [20:37<19:25, 17.93s/it, v_num=0, train_loss=0.220]\n",
      "Epoch 0:  52%|█████▏    | 70/134 [20:54<19:07, 17.93s/it, v_num=0, train_loss=0.220]\n",
      "Epoch 0:  52%|█████▏    | 70/134 [20:55<19:07, 17.93s/it, v_num=0, train_loss=0.249]\n",
      "Epoch 0:  53%|█████▎    | 71/134 [21:12<18:49, 17.93s/it, v_num=0, train_loss=0.249]\n",
      "Epoch 0:  53%|█████▎    | 71/134 [21:12<18:49, 17.93s/it, v_num=0, train_loss=0.231]\n",
      "Epoch 0:  54%|█████▎    | 72/134 [21:30<18:31, 17.92s/it, v_num=0, train_loss=0.231]\n",
      "Epoch 0:  54%|█████▎    | 72/134 [21:30<18:31, 17.92s/it, v_num=0, train_loss=0.206]\n",
      "Epoch 0:  54%|█████▍    | 73/134 [21:48<18:13, 17.92s/it, v_num=0, train_loss=0.206]\n",
      "Epoch 0:  54%|█████▍    | 73/134 [21:48<18:13, 17.92s/it, v_num=0, train_loss=0.266]\n",
      "Epoch 0:  55%|█████▌    | 74/134 [22:05<17:55, 17.92s/it, v_num=0, train_loss=0.266]\n",
      "Epoch 0:  55%|█████▌    | 74/134 [22:06<17:55, 17.92s/it, v_num=0, train_loss=0.252]\n",
      "Epoch 0:  56%|█████▌    | 75/134 [22:23<17:37, 17.92s/it, v_num=0, train_loss=0.252]\n",
      "Epoch 0:  56%|█████▌    | 75/134 [22:23<17:37, 17.92s/it, v_num=0, train_loss=0.218]\n",
      "Epoch 0:  57%|█████▋    | 76/134 [22:41<17:19, 17.92s/it, v_num=0, train_loss=0.218]\n",
      "Epoch 0:  57%|█████▋    | 76/134 [22:41<17:19, 17.92s/it, v_num=0, train_loss=0.195]\n",
      "Epoch 0:  57%|█████▋    | 77/134 [22:59<17:01, 17.91s/it, v_num=0, train_loss=0.195]\n",
      "Epoch 0:  57%|█████▋    | 77/134 [22:59<17:01, 17.91s/it, v_num=0, train_loss=0.210]\n",
      "Epoch 0:  58%|█████▊    | 78/134 [23:17<16:42, 17.91s/it, v_num=0, train_loss=0.210]\n",
      "Epoch 0:  58%|█████▊    | 78/134 [23:17<16:43, 17.91s/it, v_num=0, train_loss=0.198]\n",
      "Epoch 0:  59%|█████▉    | 79/134 [23:34<16:24, 17.91s/it, v_num=0, train_loss=0.198]\n",
      "Epoch 0:  59%|█████▉    | 79/134 [23:34<16:24, 17.91s/it, v_num=0, train_loss=0.232]\n",
      "Epoch 0:  60%|█████▉    | 80/134 [23:52<16:06, 17.90s/it, v_num=0, train_loss=0.232]\n",
      "Epoch 0:  60%|█████▉    | 80/134 [23:52<16:06, 17.90s/it, v_num=0, train_loss=0.267]\n",
      "Epoch 0:  60%|██████    | 81/134 [24:09<15:48, 17.90s/it, v_num=0, train_loss=0.267]\n",
      "Epoch 0:  60%|██████    | 81/134 [24:10<15:48, 17.90s/it, v_num=0, train_loss=0.244]\n",
      "Epoch 0:  61%|██████    | 82/134 [24:27<15:30, 17.90s/it, v_num=0, train_loss=0.244]\n",
      "Epoch 0:  61%|██████    | 82/134 [24:27<15:30, 17.90s/it, v_num=0, train_loss=0.173]\n",
      "Epoch 0:  62%|██████▏   | 83/134 [24:45<15:12, 17.89s/it, v_num=0, train_loss=0.173]\n",
      "Epoch 0:  62%|██████▏   | 83/134 [24:45<15:12, 17.89s/it, v_num=0, train_loss=0.225]\n",
      "Epoch 0:  63%|██████▎   | 84/134 [25:02<14:54, 17.89s/it, v_num=0, train_loss=0.225]\n",
      "Epoch 0:  63%|██████▎   | 84/134 [25:03<14:54, 17.89s/it, v_num=0, train_loss=0.231]\n",
      "Epoch 0:  63%|██████▎   | 85/134 [25:20<14:36, 17.89s/it, v_num=0, train_loss=0.231]\n",
      "Epoch 0:  63%|██████▎   | 85/134 [25:20<14:36, 17.89s/it, v_num=0, train_loss=0.235]\n",
      "Epoch 0:  64%|██████▍   | 86/134 [25:38<14:18, 17.88s/it, v_num=0, train_loss=0.235]\n",
      "Epoch 0:  64%|██████▍   | 86/134 [25:38<14:18, 17.89s/it, v_num=0, train_loss=0.257]\n",
      "Epoch 0:  65%|██████▍   | 87/134 [25:55<14:00, 17.88s/it, v_num=0, train_loss=0.257]\n",
      "Epoch 0:  65%|██████▍   | 87/134 [25:56<14:00, 17.89s/it, v_num=0, train_loss=0.297]\n",
      "Epoch 0:  66%|██████▌   | 88/134 [26:13<13:42, 17.88s/it, v_num=0, train_loss=0.297]\n",
      "Epoch 0:  66%|██████▌   | 88/134 [26:13<13:42, 17.88s/it, v_num=0, train_loss=0.277]\n",
      "Epoch 0:  66%|██████▋   | 89/134 [26:31<13:24, 17.88s/it, v_num=0, train_loss=0.277]\n",
      "Epoch 0:  66%|██████▋   | 89/134 [26:31<13:24, 17.88s/it, v_num=0, train_loss=0.264]\n",
      "Epoch 0:  67%|██████▋   | 90/134 [26:49<13:06, 17.88s/it, v_num=0, train_loss=0.264]\n",
      "Epoch 0:  67%|██████▋   | 90/134 [26:49<13:06, 17.88s/it, v_num=0, train_loss=0.269]\n",
      "Epoch 0:  68%|██████▊   | 91/134 [27:06<12:48, 17.88s/it, v_num=0, train_loss=0.269]\n",
      "Epoch 0:  68%|██████▊   | 91/134 [27:07<12:48, 17.88s/it, v_num=0, train_loss=0.268]\n",
      "Epoch 0:  69%|██████▊   | 92/134 [27:24<12:30, 17.88s/it, v_num=0, train_loss=0.268]\n",
      "Epoch 0:  69%|██████▊   | 92/134 [27:24<12:30, 17.88s/it, v_num=0, train_loss=0.200]\n",
      "Epoch 0:  69%|██████▉   | 93/134 [27:42<12:12, 17.87s/it, v_num=0, train_loss=0.200]\n",
      "Epoch 0:  69%|██████▉   | 93/134 [27:42<12:12, 17.88s/it, v_num=0, train_loss=0.229]\n",
      "Epoch 0:  70%|███████   | 94/134 [28:00<11:54, 17.87s/it, v_num=0, train_loss=0.229]\n",
      "Epoch 0:  70%|███████   | 94/134 [28:00<11:54, 17.87s/it, v_num=0, train_loss=0.248]\n",
      "Epoch 0:  71%|███████   | 95/134 [28:17<11:37, 17.87s/it, v_num=0, train_loss=0.248]\n",
      "Epoch 0:  71%|███████   | 95/134 [28:18<11:37, 17.87s/it, v_num=0, train_loss=0.217]\n",
      "Epoch 0:  72%|███████▏  | 96/134 [28:35<11:19, 17.87s/it, v_num=0, train_loss=0.217]\n",
      "Epoch 0:  72%|███████▏  | 96/134 [28:35<11:19, 17.87s/it, v_num=0, train_loss=0.217]\n",
      "Epoch 0:  72%|███████▏  | 97/134 [28:53<11:01, 17.87s/it, v_num=0, train_loss=0.217]\n",
      "Epoch 0:  72%|███████▏  | 97/134 [28:53<11:01, 17.87s/it, v_num=0, train_loss=0.238]\n",
      "Epoch 0:  73%|███████▎  | 98/134 [29:11<10:43, 17.87s/it, v_num=0, train_loss=0.238]\n",
      "Epoch 0:  73%|███████▎  | 98/134 [29:11<10:43, 17.87s/it, v_num=0, train_loss=0.168]\n",
      "Epoch 0:  74%|███████▍  | 99/134 [29:28<10:25, 17.87s/it, v_num=0, train_loss=0.168]\n",
      "Epoch 0:  74%|███████▍  | 99/134 [29:29<10:25, 17.87s/it, v_num=0, train_loss=0.198]\n",
      "Epoch 0:  75%|███████▍  | 100/134 [29:46<10:07, 17.87s/it, v_num=0, train_loss=0.198]\n",
      "Epoch 0:  75%|███████▍  | 100/134 [29:46<10:07, 17.87s/it, v_num=0, train_loss=0.205]\n",
      "Epoch 0:  75%|███████▌  | 101/134 [30:04<09:49, 17.86s/it, v_num=0, train_loss=0.205]\n",
      "Epoch 0:  75%|███████▌  | 101/134 [30:04<09:49, 17.87s/it, v_num=0, train_loss=0.165]\n",
      "Epoch 0:  76%|███████▌  | 102/134 [30:21<09:31, 17.86s/it, v_num=0, train_loss=0.165]\n",
      "Epoch 0:  76%|███████▌  | 102/134 [30:22<09:31, 17.86s/it, v_num=0, train_loss=0.261]\n",
      "Epoch 0:  77%|███████▋  | 103/134 [30:39<09:13, 17.86s/it, v_num=0, train_loss=0.261]\n",
      "Epoch 0:  77%|███████▋  | 103/134 [30:39<09:13, 17.86s/it, v_num=0, train_loss=0.250]\n",
      "Epoch 0:  78%|███████▊  | 104/134 [30:57<08:55, 17.86s/it, v_num=0, train_loss=0.250]\n",
      "Epoch 0:  78%|███████▊  | 104/134 [30:57<08:55, 17.86s/it, v_num=0, train_loss=0.161]\n",
      "Epoch 0:  78%|███████▊  | 105/134 [31:15<08:37, 17.86s/it, v_num=0, train_loss=0.161]\n",
      "Epoch 0:  78%|███████▊  | 105/134 [31:15<08:38, 17.86s/it, v_num=0, train_loss=0.202]\n",
      "Epoch 0:  79%|███████▉  | 106/134 [31:33<08:20, 17.86s/it, v_num=0, train_loss=0.202]\n",
      "Epoch 0:  79%|███████▉  | 106/134 [31:33<08:20, 17.86s/it, v_num=0, train_loss=0.177]\n",
      "Epoch 0:  80%|███████▉  | 107/134 [31:50<08:02, 17.86s/it, v_num=0, train_loss=0.177]\n",
      "Epoch 0:  80%|███████▉  | 107/134 [31:51<08:02, 17.86s/it, v_num=0, train_loss=0.225]\n",
      "Epoch 0:  81%|████████  | 108/134 [32:08<07:44, 17.86s/it, v_num=0, train_loss=0.225]\n",
      "Epoch 0:  81%|████████  | 108/134 [32:08<07:44, 17.86s/it, v_num=0, train_loss=0.188]\n",
      "Epoch 0:  81%|████████▏ | 109/134 [32:26<07:26, 17.86s/it, v_num=0, train_loss=0.188]\n",
      "Epoch 0:  81%|████████▏ | 109/134 [32:26<07:26, 17.86s/it, v_num=0, train_loss=0.205]\n",
      "Epoch 0:  82%|████████▏ | 110/134 [32:44<07:08, 17.86s/it, v_num=0, train_loss=0.205]\n",
      "Epoch 0:  82%|████████▏ | 110/134 [32:44<07:08, 17.86s/it, v_num=0, train_loss=0.218]\n",
      "Epoch 0:  83%|████████▎ | 111/134 [33:02<06:50, 17.86s/it, v_num=0, train_loss=0.218]\n",
      "Epoch 0:  83%|████████▎ | 111/134 [33:02<06:50, 17.86s/it, v_num=0, train_loss=0.259]\n",
      "Epoch 0:  84%|████████▎ | 112/134 [33:20<06:32, 17.86s/it, v_num=0, train_loss=0.259]\n",
      "Epoch 0:  84%|████████▎ | 112/134 [33:20<06:32, 17.86s/it, v_num=0, train_loss=0.255]\n",
      "Epoch 0:  84%|████████▍ | 113/134 [33:38<06:15, 17.86s/it, v_num=0, train_loss=0.255]\n",
      "Epoch 0:  84%|████████▍ | 113/134 [33:38<06:15, 17.86s/it, v_num=0, train_loss=0.221]\n",
      "Epoch 0:  85%|████████▌ | 114/134 [33:55<05:57, 17.86s/it, v_num=0, train_loss=0.221]\n",
      "Epoch 0:  85%|████████▌ | 114/134 [33:56<05:57, 17.86s/it, v_num=0, train_loss=0.185]\n",
      "Epoch 0:  86%|████████▌ | 115/134 [34:13<05:39, 17.86s/it, v_num=0, train_loss=0.185]\n",
      "Epoch 0:  86%|████████▌ | 115/134 [34:13<05:39, 17.86s/it, v_num=0, train_loss=0.189]\n",
      "Epoch 0:  87%|████████▋ | 116/134 [34:31<05:21, 17.86s/it, v_num=0, train_loss=0.189]\n",
      "Epoch 0:  87%|████████▋ | 116/134 [34:31<05:21, 17.86s/it, v_num=0, train_loss=0.168]\n",
      "Epoch 0:  87%|████████▋ | 117/134 [34:48<05:03, 17.85s/it, v_num=0, train_loss=0.168]\n",
      "Epoch 0:  87%|████████▋ | 117/134 [34:49<05:03, 17.86s/it, v_num=0, train_loss=0.166]\n",
      "Epoch 0:  88%|████████▊ | 118/134 [35:06<04:45, 17.85s/it, v_num=0, train_loss=0.166]\n",
      "Epoch 0:  88%|████████▊ | 118/134 [35:06<04:45, 17.86s/it, v_num=0, train_loss=0.184]\n",
      "Epoch 0:  89%|████████▉ | 119/134 [35:24<04:27, 17.85s/it, v_num=0, train_loss=0.184]\n",
      "Epoch 0:  89%|████████▉ | 119/134 [35:24<04:27, 17.86s/it, v_num=0, train_loss=0.230]\n",
      "Epoch 0:  90%|████████▉ | 120/134 [35:42<04:09, 17.85s/it, v_num=0, train_loss=0.230]\n",
      "Epoch 0:  90%|████████▉ | 120/134 [35:42<04:09, 17.86s/it, v_num=0, train_loss=0.251]\n",
      "Epoch 0:  90%|█████████ | 121/134 [36:00<03:52, 17.86s/it, v_num=0, train_loss=0.251]\n",
      "Epoch 0:  90%|█████████ | 121/134 [36:00<03:52, 17.86s/it, v_num=0, train_loss=0.244]\n",
      "Epoch 0:  91%|█████████ | 122/134 [36:18<03:34, 17.85s/it, v_num=0, train_loss=0.244]\n",
      "Epoch 0:  91%|█████████ | 122/134 [36:18<03:34, 17.86s/it, v_num=0, train_loss=0.232]\n",
      "Epoch 0:  92%|█████████▏| 123/134 [36:35<03:16, 17.85s/it, v_num=0, train_loss=0.232]\n",
      "Epoch 0:  92%|█████████▏| 123/134 [36:36<03:16, 17.85s/it, v_num=0, train_loss=0.200]\n",
      "Epoch 0:  93%|█████████▎| 124/134 [36:53<02:58, 17.85s/it, v_num=0, train_loss=0.200]\n",
      "Epoch 0:  93%|█████████▎| 124/134 [36:53<02:58, 17.85s/it, v_num=0, train_loss=0.152]\n",
      "Epoch 0:  93%|█████████▎| 125/134 [37:11<02:40, 17.86s/it, v_num=0, train_loss=0.152]\n",
      "Epoch 0:  93%|█████████▎| 125/134 [37:12<02:40, 17.86s/it, v_num=0, train_loss=0.162]\n",
      "Epoch 0:  94%|█████████▍| 126/134 [37:29<02:22, 17.86s/it, v_num=0, train_loss=0.162]\n",
      "Epoch 0:  94%|█████████▍| 126/134 [37:29<02:22, 17.86s/it, v_num=0, train_loss=0.184]\n",
      "Epoch 0:  95%|█████████▍| 127/134 [37:47<02:04, 17.86s/it, v_num=0, train_loss=0.184]\n",
      "Epoch 0:  95%|█████████▍| 127/134 [37:47<02:04, 17.86s/it, v_num=0, train_loss=0.186]\n",
      "Epoch 0:  96%|█████████▌| 128/134 [38:05<01:47, 17.85s/it, v_num=0, train_loss=0.186]\n",
      "Epoch 0:  96%|█████████▌| 128/134 [38:05<01:47, 17.86s/it, v_num=0, train_loss=0.208]\n",
      "Epoch 0:  96%|█████████▋| 129/134 [38:23<01:29, 17.85s/it, v_num=0, train_loss=0.208]\n",
      "Epoch 0:  96%|█████████▋| 129/134 [38:23<01:29, 17.85s/it, v_num=0, train_loss=0.243]\n",
      "Epoch 0:  97%|█████████▋| 130/134 [38:40<01:11, 17.85s/it, v_num=0, train_loss=0.243]\n",
      "Epoch 0:  97%|█████████▋| 130/134 [38:41<01:11, 17.85s/it, v_num=0, train_loss=0.243]\n",
      "Epoch 0:  98%|█████████▊| 131/134 [38:58<00:53, 17.85s/it, v_num=0, train_loss=0.243]\n",
      "Epoch 0:  98%|█████████▊| 131/134 [38:58<00:53, 17.85s/it, v_num=0, train_loss=0.213]\n",
      "Epoch 0:  99%|█████████▊| 132/134 [39:16<00:35, 17.85s/it, v_num=0, train_loss=0.213]\n",
      "Epoch 0:  99%|█████████▊| 132/134 [39:16<00:35, 17.85s/it, v_num=0, train_loss=0.214]\n",
      "Epoch 0:  99%|█████████▉| 133/134 [39:34<00:17, 17.85s/it, v_num=0, train_loss=0.214]\n",
      "Epoch 0:  99%|█████████▉| 133/134 [39:34<00:17, 17.85s/it, v_num=0, train_loss=0.182]\n",
      "Epoch 0: 100%|██████████| 134/134 [39:52<00:00, 17.85s/it, v_num=0, train_loss=0.182]\n",
      "Epoch 0: 100%|██████████| 134/134 [39:52<00:00, 17.85s/it, v_num=0, train_loss=0.174]\n",
      "Epoch 0: : 135it [40:10, 17.85s/it, v_num=0, train_loss=0.174]                       \n",
      "Epoch 0: : 135it [40:10, 17.85s/it, v_num=0, train_loss=0.176]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(RayTrainWorker pid=13861, ip=10.0.46.116)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune_dolly-v2-7b/TorchTrainer_839b5_00000_0_2023-08-30_11-03-25/checkpoint_000000)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune_dolly-v2-7b/TorchTrainer_839b5_00000_0_2023-08-30_11-03-25/checkpoint_000000)\u001b[32m [repeated 15x across cluster]\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: : 135it [46:03, 20.47s/it, v_num=0, train_loss=0.176]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m `Trainer.fit` stopped: `max_epochs=1` reached.\n",
      "\u001b[2m\u001b[36m(RayTrainWorker pid=66181)\u001b[0m RayFSDPStrategy: tearing down strategy...\n",
      "2023-08-30 11:51:22,248\tWARNING experiment_state.py:371 -- Experiment checkpoint syncing has been triggered multiple times in the last 30.0 seconds. A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. You can supress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.\n",
      "2023-08-30 11:51:22,253\tINFO tune.py:1142 -- Total run time: 2877.24 seconds (2877.14 seconds for the tuning loop).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Result(\n",
       "  metrics={'train_loss': 0.176025390625, 'epoch': 0, 'step': 135},\n",
       "  path='/mnt/cluster_storage/finetune_dolly-v2-7b/TorchTrainer_839b5_00000_0_2023-08-30_11-03-25',\n",
       "  filesystem='local',\n",
       "  checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune_dolly-v2-7b/TorchTrainer_839b5_00000_0_2023-08-30_11-03-25/checkpoint_000000)\n",
       ")"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from ray.train.torch import TorchTrainer\n",
    "from ray.train import RunConfig, ScalingConfig, CheckpointConfig\n",
    "\n",
    "# Keep only the most recent Ray Train checkpoint (no score attribute is set)\n",
    "run_config = RunConfig(\n",
    "    name=\"finetune_dolly-v2-7b\",\n",
    "    storage_path=storage_path,\n",
    "    checkpoint_config=CheckpointConfig(num_to_keep=1),\n",
    ")\n",
    "\n",
    "# Scale the FSDP training workload across 16 GPUs\n",
    "# You can change this config based on your compute resources.\n",
    "scaling_config = ScalingConfig(\n",
    "    num_workers=num_workers, use_gpu=True, trainer_resources={\"memory\": 100 * 1024 ** 3}\n",
    ")\n",
    "\n",
    "# Configuration to pass into train_func\n",
    "train_config = {\n",
    "    \"lr\": 2e-5,\n",
    "    \"eps\": 1e-8,\n",
    "    \"strategy\": fsdp_strategy,\n",
    "    \"batch_size_per_worker\": 10\n",
    "}\n",
    "\n",
    "# Define a TorchTrainer and launch your training workload\n",
    "ray_trainer = TorchTrainer(\n",
    "    train_func,\n",
    "    train_loop_config=train_config,\n",
    "    run_config=run_config,\n",
    "    scaling_config=scaling_config,\n",
    "    datasets={\"train\": train_ds},\n",
    ")\n",
    "result = ray_trainer.fit()\n",
    "\n",
    "result\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We finished training in about 2877 seconds (roughly 48 minutes). An on-demand g4dn.4xlarge instance costs `$1.204/hour` and a g4dn.8xlarge instance costs `$2.176/hour`, so with 15 worker nodes and one head node the total cost is approximately `($1.204 * 15 + $2.176) * 2877 / 3600 = $16.17`."
   ]
  },
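  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the next cell recomputes this estimate. It is a minimal sketch that uses only the on-demand prices and the run time quoted above. Real prices vary by region and over time, and the calculation ignores any charges other than instance hours (for example, storage)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# On-demand hourly prices (USD) quoted above; adjust them for your region and current pricing.\n",
    "head_price_per_hour = 2.176    # one g4dn.8xlarge head node\n",
    "worker_price_per_hour = 1.204  # each g4dn.4xlarge worker node\n",
    "num_worker_nodes = 15\n",
    "\n",
    "# Total run time reported in the training logs above, in seconds\n",
    "training_time_s = 2877\n",
    "\n",
    "# Cluster price per hour multiplied by the run time in hours\n",
    "cluster_price_per_hour = worker_price_per_hour * num_worker_nodes + head_price_per_hour\n",
    "total_cost = cluster_price_per_hour * training_time_s / 3600\n",
    "print(f\"Estimated cost: ${total_cost:.2f}\")"
   ]
  },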
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text generation with HuggingFace Pipeline\n",
    "\n",
    "We can use the [HuggingFace Pipeline](https://huggingface.co/docs/transformers/main_classes/pipelines) to generate text from our fine-tuned model. Let's input some prompts and see whether our fine-tuned Dolly can speak like Shakespeare:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
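    "import os\n",
    "\n",
    "import torch\n",
    "from transformers import AutoTokenizer, pipeline\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side=\"right\")\n",
    "\n",
    "# Path to the Lightning checkpoint file inside the Ray Train checkpoint directory\n",
    "ckpt_path = os.path.join(result.checkpoint.path, \"checkpoint.ckpt\")\n",
    "\n",
    "# Load the fine-tuned LightningModule onto the CPU\n",
    "dolly = DollyV2Model.load_from_checkpoint(ckpt_path, map_location=torch.device(\"cpu\"))\n",
    "\n",
    "# Wrap the underlying HuggingFace model in a text-generation pipeline\n",
    "nlp_pipeline = pipeline(\n",
    "    task=\"text-generation\",\n",
    "    model=dolly.model,\n",
    "    tokenizer=tokenizer,\n",
    "    device_map=\"auto\"\n",
    ")\n"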
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'generated_text': \"This is the day that our hearts live to love. Now come, go in; I'll sit here:\"}]\n",
      "[{'generated_text': 'I am very sorry, not a jot. What would you have? your pardon? my good lord?'}]\n",
      "[{'generated_text': 'Once more, look up, look up, my sovereign; look up this night!'}]\n"
     ]
    }
   ],
   "source": [
    "for prompt in [\"This is\", \"I am\", \"Once more\"]:\n",
    "    print(nlp_pipeline(prompt, max_new_tokens=20, do_sample=True, pad_token_id=tokenizer.eos_token_id))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "References:\n",
    "- [PyTorch FSDP Tutorial](https://www.youtube.com/watch?v=8_k76AHu__s&list=PL_lsbAsL_o2BT6aerEKgIoufVD_fodnuT)\n",
    "- [Getting Started with Fully Sharded Data Parallel(FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html#:~:text=FSDP%20is%20a%20type%20of,sizes%20for%20our%20training%20job.)\n",
    "- [Fully Sharded Data Parallel: faster AI training with fewer GPUs](https://engineering.fb.com/2021/07/15/open-source/fsdp/)\n",
    "- [Hugging Face: dolly-v2-7b Model Card](https://huggingface.co/databricks/dolly-v2-7b)\n",
    "- [Hugging Face: Handling big models for inference](https://huggingface.co/docs/accelerate/usage_guides/big_modeling)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
